import torch
from torch import nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import dgl
from dgl.nn.pytorch import GraphConv
from dgl.nn.pytorch import ChebConv
import networkx as nx

def create_dgl_graph(G, features, directed=False):
    '''
        Creates the DGL graph representation of the line graph of G
        (every edge of G becomes a node of the DGL graph).

        Parameters:
            G: networkx graph; its line graph is taken with
                nx.line_graph.
            features: mapping from each line-graph node (an edge of G)
                to a 1-D feature vector exposing `.shape`; all vectors
                are assumed to have the length of the first one.
            directed: when falsy, every line-graph edge is inserted in
                both directions; when truthy, only in the direction
                reported by the line graph.

        Every node additionally receives a self-loop, and the feature
        vectors are stored in ndata['feat'].

        Returns:
            (G_dgl, edge_map), where edge_map maps each edge of G (as
            named by the line graph) to its node id in G_dgl. The graph
            is moved to 'cuda:0' when CUDA is available.

        Raises:
            ValueError: if features is empty.
    '''
    G_line = nx.line_graph(G)

    # Line-graph nodes are numbered in iteration order.
    edge_map = {e: i for i, e in enumerate(G_line.nodes())}

    # Collect endpoints as Python ints so ids never pass through floats.
    source = []
    dest = []
    for e_i, e_j in G_line.edges():
        m_i = edge_map[e_i]
        m_j = edge_map[e_j]

        source.append(m_i)
        dest.append(m_j)

        # Same truthiness test everywhere: a falsy non-bool `directed`
        # must also get its reverse edges.
        if not directed:
            source.append(m_j)
            dest.append(m_i)

    # Self-loops go last, in node-id order.
    self_loops = list(edge_map.values())
    source.extend(self_loops)
    dest.extend(self_loops)

    source = torch.tensor(source, dtype=torch.int64)
    dest = torch.tensor(dest, dtype=torch.int64)
    G_dgl = dgl.graph((source, dest))

    if len(features) == 0:
        raise ValueError("features must contain at least one entry")

    n_features = features[next(iter(features.keys()))].shape[0]

    # Fill the feature matrix completely before attaching it, so the result
    # does not depend on ndata handing back the stored tensor for in-place
    # writes.
    feat = torch.zeros((G.number_of_edges(), n_features))
    for e, i in edge_map.items():
        feat[i] = torch.FloatTensor(features[e])

    G_dgl.ndata['feat'] = feat

    if torch.cuda.is_available():
        return G_dgl.to(torch.device('cuda:0')), edge_map

    return G_dgl, edge_map

def get_feat_ids(G, flows, edge_map):
    '''
        Collects the ids (values of edge_map) of the edges of G that
        appear in flows, in the order G.edges() yields them.

        Parameters:
            G: object whose edges() yields edge keys.
            flows: container tested with `edge in flows`.
            edge_map: mapping from edge key to integer id.

        Returns:
            1-D int64 tensor of ids, created on 'cuda:0' when CUDA is
            available and on the default device otherwise.
    '''
    feat_ids = [edge_map[e] for e in G.edges() if e in flows]

    device = 'cuda:0' if torch.cuda.is_available() else None

    # The dtype is explicit because an empty list would otherwise become a
    # float32 tensor, which cannot be used to index another tensor.
    return torch.tensor(feat_ids, dtype=torch.long, device=device)

class GCN(nn.Module):
    '''
        2-layer graph convolutional network built from Chebyshev
        convolutions (ChebConv), producing one output per node.
    '''
    def __init__(self, n_features, n_hidden, n_iter, lr, lamb_max, early_stop=10, output_activation=torch.sigmoid):
        '''
            Parameters:
                n_features: input feature size of the first layer.
                n_hidden: output size of the first layer.
                n_iter: maximum number of training epochs.
                lr: learning rate of the Adam optimizer.
                lamb_max: value passed as the third argument of every
                    ChebConv call (lambda_max in DGL's ChebConv,
                    presumably the largest Laplacian eigenvalue --
                    confirm against the caller).
                early_stop: number of previous validation losses
                    averaged by the early-stopping test.
                output_activation: callable applied to the output of
                    the second layer.
        '''
        super(GCN, self).__init__()

        # Both layers are created with 2 as ChebConv's third argument.
        self.conv1 = ChebConv(n_features, n_hidden, 2)
        self.conv2 = ChebConv(n_hidden, 1, 2)

        self.n_iter = n_iter
        self.early_stop = early_stop
        self.output_activation = output_activation
        self.lamb_max = lamb_max

        self.use_cuda = torch.cuda.is_available()

        # Move the module before building the optimizer, as the
        # nn.Module.cuda documentation requires.
        if self.use_cuda:
            self.cuda()

        self.optimizer = optim.Adam(self.parameters(), lr=lr)

    def forward(self, g, inputs):
        '''
            Applies conv1, ReLU, conv2 and the output activation to
            inputs on graph g and returns the result.
        '''
        h = self.conv1(g, inputs, self.lamb_max)
        h = torch.relu(h)
        h = self.output_activation(self.conv2(g, h, self.lamb_max))

        return h

    def train(self, dgl_G=True, edge_map=None, xs_train=None, ys_train=None, xs_valid=None, ys_valid=None, verbose=False):
        '''
            Fits the network with MSE loss and early stopping.

            This method shadows nn.Module.train(mode). When the first
            argument is a bool (or omitted), the call is forwarded to
            nn.Module.train so that eval() and train(True/False) keep
            working; only the positional form is supported for that.

            Parameters:
                dgl_G: DGL graph with node features in ndata['feat'].
                edge_map: not used by this method.
                xs_train, xs_valid: index tensors selecting the nodes
                    used for the training / validation loss.
                ys_train, ys_valid: target values for those nodes.
                verbose: when truthy, prints the losses every 1000
                    epochs and a message on early stopping.

            Training stops once epoch > early_stop and the latest
            validation loss exceeds the mean of the early_stop losses
            before it.

            Raises:
                TypeError: if a graph is given without the four
                    train/validation tensors.
        '''
        # nn.Module.eval() calls self.train(False); route mode switches to
        # the base class instead of the training loop.
        if isinstance(dgl_G, bool):
            return super(GCN, self).train(dgl_G)

        if any(arg is None for arg in (xs_train, ys_train, xs_valid, ys_valid)):
            raise TypeError("train() requires xs_train, ys_train, xs_valid and ys_valid")

        loss_func = nn.MSELoss()
        valid_losses = []
        for epoch in range(self.n_iter):
            self.optimizer.zero_grad()
            # The output has a single column; take it as a flat vector.
            outputs = self(dgl_G, dgl_G.ndata['feat'])[:, 0]
            train_loss = loss_func(outputs[xs_train], ys_train)

            # The validation loss is only monitored, never backpropagated.
            with torch.no_grad():
                valid_loss = loss_func(outputs[xs_valid], ys_valid).item()

            train_loss.backward()
            self.optimizer.step()

            valid_losses.append(valid_loss)

            if verbose and epoch % 1000 == 0:
                print("epoch: ", epoch, " train loss = ", train_loss.item(), " valid loss = ", valid_loss)

            if epoch > self.early_stop and valid_loss > np.mean(valid_losses[-(self.early_stop + 1):-1]):
                if verbose:
                    print("Early stopping...")
                break
                
